# replication code for
# 'Using General Messages to Persuade on a Politicized Scientific Issue'
library(tidyverse)
library(data.table)
library(dtplyr)
library(colorblindr)
library(here)

# House theme for all figures: theme_classic() with serif 16pt text, bold
# facet-strip labels and a bold 20pt plot title.
theme_jg <- function(){
  serif_overrides <- theme(
    text = element_text(family = "serif", size = 16),
    strip.text = element_text(face = "bold"),
    plot.title = element_text(face = "bold", size = 20)
  )
  theme_classic() + serif_overrides
}

# here::here() only reports the project root it detected; the return value is
# discarded and the working directory is not changed. The relative paths used
# from here on are therefore resolved against the current working directory,
# presumably the project root (NOTE(review): confirm how the script is meant
# to be launched).
here::here()
# Load project helper functions. NOTE(review): prop.ind.sig(), called further
# down, is not defined in this file and presumably comes from this script.
source("code/vaccine_experiments_functions.R")

library(srvyr)
library(grf)
library(multcomp)
library(ggpattern)
library(geomnet)
library(cowplot)
# Draw a text label styled after one element of a ggplot2 theme.
#
# label:   text handed to cowplot::draw_label().
# theme:   a ggplot2 theme; when NULL the currently active theme
#          (ggplot2::theme_get()) is used.
# element: name of the theme element whose family, face, colour and size are
#          applied to the label (e.g. "plot.title").
# ...:     forwarded to cowplot::draw_label() (position, justification, ...).
#
# Stops with an error when `element` is not a name present in `theme`.
draw_label_theme <- function(label, theme = NULL, element = "text", ...) {
  if (is.null(theme)) {
    theme <- ggplot2::theme_get()
  }
  if (!element %in% names(theme)) {
    stop("Element must be a valid ggplot theme element name")
  }
  
  elements <- ggplot2::calc_element(element, theme)
  
  # ggplot2 stores the text colour under `colour`. `color` is not a prefix of
  # that name, so `elements$color` never partial-matched and was always NULL,
  # silently dropping the theme's colour.
  cowplot::draw_label(label, 
                      fontfamily = elements$family,
                      fontface = elements$face,
                      colour = elements$colour,
                      size = elements$size,
                      ...
  )
}


# Read the cleaned wave-14 message-experiment data.
wave14.c.exp.oh <- data.table::fread("data/w14_vac_message_cleaned_onehot.csv")
# Add the two columns the ATE models below rely on:
#   vac_ex_mes_names      - message condition as a factor with syntactic
#                           level names (make.names() turns spaces into dots)
#   vac_message_resistant - 1 when vac_message equals 1, otherwise 0
wave14.c.exp.oh[, `:=`(
  vac_ex_mes_names = factor(make.names(vac_ex_mes_trunc)),
  vac_message_resistant = as.numeric(vac_message == 1)
)]

# ATEs, adjusting for multiple comparisons
# One-way model of the vac_message response on message condition. The
# intercept is the mean of the factor's reference level; the figures below
# label that estimate as the control (NOTE(review): confirm the reference
# level of vac_ex_mes_names is the control condition).
likelihood_comparisons <- 
  aov(vac_message ~ vac_ex_mes_names, 
      data = as_tibble(wave14.c.exp.oh))

# glht() is called without `linfct`, so every model coefficient (intercept and
# each condition term) is tested against zero; summary() reports p-values
# adjusted across those tests (multcomp's default adjustment).
likelihood_sum_mc <- 
  summary(glht(likelihood_comparisons))
likelihood_sum_mc

# Same specification for the binary indicator vac_message_resistant
# (vac_message == 1), i.e. a linear model of a 0/1 outcome.
resistance_comparisons <- 
  aov(vac_message_resistant ~ vac_ex_mes_names, 
      data = as_tibble(wave14.c.exp.oh))

resistance_sum_mc <- 
  summary(glht(resistance_comparisons))
resistance_sum_mc

# put in table
# One row per model coefficient; estimate, standard error and adjusted
# p-value for each outcome side by side. Row names come from the named
# coefficient vector and are copied into `var`.
ate_table <- with(
  list(res = resistance_sum_mc$test, lik = likelihood_sum_mc$test),
  data.frame(resistance_coefs = res$coefficients,
             resistance_se = res$sigma,
             resistance_p = res$pvalues,
             likelihood_coefs = lik$coefficients,
             likelihood_se = lik$sigma,
             likelihood_p = lik$pvalues)
)

ate_table[["var"]] <- rownames(ate_table)

# shape for plotting
# Long format first (one row per coefficient x statistic), then back to one
# row per coefficient x outcome with columns coef, se and p.
ate_plot_table <-
  reshape2::melt(ate_table, id.vars = "var") %>%
  mutate(
    # column names are "<outcome>_<statistic>"
    outcome = if_else(startsWith(as.character(variable), "likelihood"),
                      "likelihood", "resistance"),
    out_type = case_when(
      endsWith(as.character(variable), "_se") ~ "se",
      endsWith(as.character(variable), "_p") ~ "p",
      endsWith(as.character(variable), "_coefs") ~ "coef"
    )
  ) %>%
  dplyr::select(-variable) %>%
  pivot_wider(names_from = "out_type",
              values_from = "value") %>%
  # strip the factor-name prefix, then turn make.names() dots back into spaces
  mutate(var = gsub(".", " ",
                    gsub("vac_ex_mes_names", "", var, fixed = TRUE),
                    fixed = TRUE))

# get fitted value from intercept + coef
# Intercept rows keep their own coefficient; every other row is that
# outcome's intercept plus the row's coefficient. Computed as one vectorised
# lookup instead of a per-row scan of the table, so the result is always a
# numeric vector (also for an empty table).
ate_plot_table$pred <- local({
  is_intercept <- ate_plot_table$var == "(Intercept)"
  # intercept estimate of each model, keyed by outcome name
  intercept_by_outcome <- setNames(
    ate_plot_table$coef[is_intercept],
    as.character(ate_plot_table$outcome[is_intercept])
  )
  baseline <- unname(intercept_by_outcome[as.character(ate_plot_table$outcome)])
  ifelse(is_intercept,
         ate_plot_table$coef,
         baseline + ate_plot_table$coef)
})

# uncertainty interval, plotting order and significance flag
ate_plot_table <- ate_plot_table %>%
  # interval: fitted value +/- 1.96 standard errors
  mutate(lwr = pred - 1.96*se,
         upr = pred + 1.96*se) %>%
  # the raw coefficient is no longer needed once `pred` exists
  dplyr::select(-coef) %>%
  mutate(
    outcome = factor(outcome, levels = c("resistance", "likelihood")),
    var = factor(var, levels = c("(Intercept)",
                                 "Patriotism",
                                 "People you know",
                                 "Preventing harm",
                                 "Physician recommend",
                                 "Scientists recommend")),
    # flag significance (using pval, not interval)
    sig = as.numeric(p <= .05)
  )

# plot
# Build the ATE figure for one outcome. The resistance and likelihood figures
# differ only in the rows they use and in the title of the estimate axis.
#
# outcome_name: value of `outcome` in `plot_data` to draw ("resistance" or
#               "likelihood").
# y_label:      title of the estimate axis.
# plot_data:    table with columns var, outcome, pred, lwr, upr and sig.
#
# Returns a ggplot: one point-range per treatment condition (faded when
# sig == 0) and the "(Intercept)" row drawn as a dashed line inside a shaded
# band that spans all conditions.
plot_ate_mc <- function(outcome_name, y_label, plot_data = ate_plot_table) {
  # treatment conditions in axis order, mapped to their line-wrapped labels
  condition_labels <- c("Patriotism" = "Patriotism",
                        "People you know" = "People\nyou know",
                        "Preventing harm" = "Preventing\nharm",
                        "Physician recommend" = "Physician\nrecommend",
                        "Scientists recommend" = "Scientists\nrecommend")
  condition_levels <- names(condition_labels)

  outcome_rows <- plot_data %>%
    filter(outcome == outcome_name)

  treated <- outcome_rows %>%
    filter(var != "(Intercept)") %>%
    mutate(condition = factor(var, levels = condition_levels))

  # Repeat the control estimate at every condition position (1..5 on the
  # discrete axis) so the band and dashed line run from the first to the last
  # condition. merge(by = NULL) is a cross join.
  control_band <- merge(
    data.frame(x = as.numeric(seq_along(condition_levels))),
    outcome_rows %>%
      filter(var == "(Intercept)") %>%
      dplyr::select(pred, lwr, upr) %>%
      as.data.frame(),
    by = NULL
  )

  ggplot()+
    geom_pointrange(data = treated,
                    aes(x = condition,
                        y = pred, ymin = lwr,
                        ymax = upr, alpha = factor(sig)))+
    geom_ribbon(data = control_band,
                aes(x = x, ymin = lwr, ymax = upr),
                alpha = .25)+
    geom_line(data = control_band,
              aes(x = x, y = pred),
              lty = "dashed")+
    scale_alpha_manual(name = "", breaks = c(0,1),
                       values = c(.25, 1),
                       labels = c("ns","s"))+
    # guides(<aes> = FALSE) is deprecated in ggplot2; "none" is the supported
    # spelling
    guides(alpha = "none")+
    scale_x_discrete(name = "Condition",
                     breaks = condition_levels,
                     labels = unname(condition_labels))+
    labs(y = y_label, title = "Average Treatment Effects",
         subtitle = "Control estimate and 95% uncertainty interval shown with dashed line in shaded band\nEffects significant at p < .05, adjusted for multiple comparisons, darkened")+
    coord_flip() +
    theme_jg()+
    theme(text = element_text(size = 16))
}

ate_mc_plot_resistance <-
  plot_ate_mc("resistance", "Vaccine Resistance (Binary Outcome, Proportion)")
# arguments are named in full rather than relying on `file` partially
# matching `filename`
ggsave(filename = "results/figures/ate_resistance.png",
       plot = ate_mc_plot_resistance, width = 10, height = 5)

ate_mc_plot_likelihood <-
  plot_ate_mc("likelihood", "Vaccine Likelihood (1-7 Scale, Average)")
ggsave(filename = "results/figures/ate_likelihood.png",
       plot = ate_mc_plot_likelihood, width = 10, height = 5)

# estimate differences between treatment fx
# The ten pairwise differences between the five treatment coefficients,
# written as multcomp hypothesis strings. Defined once so both outcomes are
# guaranteed to be tested on the same contrasts, in the same order.
pairwise_fx_contrasts <- c("vac_ex_mes_namesPhysician.recommend - vac_ex_mes_namesPatriotism  = 0", 
                           "vac_ex_mes_namesPhysician.recommend - vac_ex_mes_namesPreventing.harm = 0",
                           "vac_ex_mes_namesPhysician.recommend - vac_ex_mes_namesPeople.you.know = 0",
                           "vac_ex_mes_namesPhysician.recommend - vac_ex_mes_namesScientists.recommend = 0",
                           "vac_ex_mes_namesScientists.recommend - vac_ex_mes_namesPreventing.harm  = 0",
                           "vac_ex_mes_namesScientists.recommend - vac_ex_mes_namesPeople.you.know  = 0",
                           "vac_ex_mes_namesScientists.recommend - vac_ex_mes_namesPatriotism  = 0",
                           "vac_ex_mes_namesPeople.you.know - vac_ex_mes_namesPreventing.harm  = 0",
                           "vac_ex_mes_namesPeople.you.know - vac_ex_mes_namesPatriotism  = 0",
                           "vac_ex_mes_namesPreventing.harm - vac_ex_mes_namesPatriotism  = 0")

diff_in_fx_resistance <- glht(resistance_comparisons, linfct = pairwise_fx_contrasts)

diff_in_fx_likelihood <- glht(likelihood_comparisons, linfct = pairwise_fx_contrasts)

# summary() computes the p-values, adjusted across the ten contrasts
s_likelihood <- summary(diff_in_fx_likelihood)
s_resistance <- summary(diff_in_fx_resistance)

# shape for plotting
# One row per contrast with estimate, adjusted p-value and a 0/1 flag for
# p < .05, for both outcomes; the likelihood summary sets the row order.
fx_comparisons_table <- local({
  # estimate / p / sig columns of one glht summary, suffixed with the outcome
  contrast_frame <- function(mc_summary, outcome_suffix) {
    estimates <- mc_summary$test$coefficients
    p_values <- mc_summary$test$pvalues
    out <- data.frame(coef = estimates,
                      p = p_values) %>%
      mutate(var = names(estimates),
             sig = as.numeric(p_values < .05))
    value_cols <- names(out) != "var"
    names(out)[value_cols] <- paste(names(out)[value_cols], outcome_suffix, sep = "_")
    out
  }

  left_join(contrast_frame(s_likelihood, "likelihood"),
            contrast_frame(s_resistance, "resistance"),
            by = "var") %>%
    dplyr::select(var, coef_likelihood, coef_resistance, 
                  p_likelihood, p_resistance,
                  sig_likelihood, sig_resistance) %>%
    mutate(var = gsub("vac_ex_mes_names", "", var))
})

# aesthetics
# Split each contrast label "A - B" into its two conditions and turn the
# make.names() dots back into spaces.
vars <- strsplit(fx_comparisons_table$var, split = " - ")

fx_comparisons_table$var1 <- vapply(vars, function(parts) {
  chartr(".", " ", parts[1])
}, character(1))

fx_comparisons_table$var2 <- vapply(vars, function(parts) {
  chartr(".", " ", parts[2])
}, character(1))

# plot
# Build the pairwise-comparison figure for one outcome: a text grid with the
# difference in treatment effects and its adjusted p-value in each cell,
# faded when the difference is not significant.
#
# outcome_name:  suffix selecting the columns coef_<name>, p_<name> and
#                sig_<name> ("likelihood" or "resistance").
# subtitle_text: subtitle of the figure.
# comparisons:   table with columns var1, var2 and the three columns above.
#
# Returns a ggplot object.
plot_fx_comparisons <- function(outcome_name, subtitle_text,
                                comparisons = fx_comparisons_table) {
  condition_order <- c("Physician recommend","Scientists recommend",
                       "People you know","Preventing harm",
                       "Patriotism")
  estimate <- comparisons[[paste0("coef_", outcome_name)]]
  p_value <- comparisons[[paste0("p_", outcome_name)]]

  cells <- data.frame(
    var1 = factor(comparisons$var1, levels = condition_order),
    var2 = factor(comparisons$var2, levels = condition_order),
    cell_label = paste0(round(estimate, 4),"\n(", round(p_value, 4), ")"),
    sig = factor(comparisons[[paste0("sig_", outcome_name)]]),
    stringsAsFactors = FALSE
  )

  ggplot(cells, aes(x = var2, y = var1))+
    geom_text(aes(label = cell_label,
                  alpha = sig),
              size = 3)+
    scale_alpha_manual(name = "",
                       breaks = c(0,1),
                       values = c(.4, 1),
                       labels = c("ns","s"))+
    scale_x_discrete(name = "Comparison Condition",
                     breaks = c("Patriotism",
                                "People you know",
                                "Preventing harm",
                                "Scientists recommend"),
                     labels = c("Patriotism",
                                "People\nyou know",
                                "Preventing\nharm",
                                "Scientists\nrecommend"))+
    scale_y_discrete(name = "Primary Condition",
                     breaks = c("People you know",
                                "Preventing harm",
                                "Scientists recommend",
                                "Physician recommend"),
                     labels = c("People\nyou know",
                                "Preventing\nharm",
                                "Scientists\nrecommend",
                                "Physician\nrecommend"))+
    labs(title = "Differences in Treatment Effects Between Conditions",
         subtitle = subtitle_text)+
    # guides(<aes> = FALSE) is deprecated in ggplot2; "none" is the supported
    # spelling
    guides(alpha = "none")+
    theme_jg()+
    theme(panel.grid.major = element_blank(),
          panel.grid.minor = element_blank())
}

# NOTE: up to this point `likelihood_comparisons` and `resistance_comparisons`
# held the fitted aov models; from here on they hold the comparison figures.
likelihood_comparisons <-
  plot_fx_comparisons("likelihood",
                      "Differences in average likelihood of taking COVID vaccine (1-7 scale)\nDifferences significant at p < .05 (adjusted for multiple comparisons) highlighted")

resistance_comparisons <-
  plot_fx_comparisons("resistance",
                      "Differences in proportion 'extremely unlikely' to take COVID vaccine\nDifferences significant at p < .05 (adjusted for multiple comparisons) highlighted")

# arguments are named in full rather than relying on `file` partially
# matching `filename`
ggsave(filename = "results/figures/comparison_resistance.png",
       plot = resistance_comparisons, width = 9, height = 4)
ggsave(filename = "results/figures/comparison_likelihood.png",
       plot = likelihood_comparisons, width = 9, height = 4)

### grf results

# load up output
# NOTE(review): `pattern` is a regular expression, not a shell glob, so
# "message*" matches any file name containing "messag" (the "*" applies to the
# final "e"). If only names starting with "message" are intended, "^message"
# would express that; confirm against the contents of output/main_output/.
files <- list.files(path = "output/main_output/",
                    pattern = "message*")

# load() restores every object stored in each file under its saved name into
# the calling environment. The message_*_binary / message_*_continuous objects
# used below are presumably created here (NOTE(review): confirm).
for(l in files){
  load(paste0("output/main_output/",l))
}

# join fits to data
# For every fitted model, take the first three columns of its `predictions`
# table (psid plus the prediction and variance-estimate columns), tag the two
# estimate columns with "_binary" or "_continuous", and left-join them onto
# the experiment data by psid, in the order listed.
wave14.c.exp.message <- local({
  # the part of each name after the last "_" becomes the column suffix
  fits <- list(
    harm_binary = message_harm_binary,
    harm_continuous = message_harm_continuous,
    patriotism_binary = message_patriotism_binary,
    patriotism_continuous = message_patriotism_continuous,
    people_binary = message_people_binary,
    people_continuous = message_people_continuous,
    physician_binary = message_physician_binary,
    physician_continuous = message_physician_continuous,
    scientists_binary = message_scientists_binary,
    scientists_continuous = message_scientists_continuous
  )

  tagged_predictions <- Map(function(fit, fit_name) {
    outcome_suffix <- sub("^.*_", "_", fit_name)
    preds <- dplyr::select(fit$predictions, 1:3)
    estimate_cols <- setdiff(names(preds), "psid")
    # rename old -> old + suffix, leaving the join key untouched
    dplyr::rename(preds,
                  !!!setNames(estimate_cols,
                              paste0(estimate_cols, outcome_suffix)))
  }, fits, names(fits))

  Reduce(function(joined, preds) left_join(joined, preds, by = "psid"),
         tagged_predictions,
         as_tibble(wave14.c.exp.oh))
})

# write summary table
# prop.ind.sig() output for every condition, binary (resistance) fits first
# and continuous (likelihood) fits second, labelled by outcome, rounded to
# three decimals and written to CSV.
bind_rows(
  bind_rows(prop.ind.sig(message_harm_binary, comparison = c("Control","Preventing harm")),
            prop.ind.sig(message_patriotism_binary, comparison = c("Control","Patriotism")),
            prop.ind.sig(message_scientists_binary, comparison = c("Control","Scientists recommend")),
            prop.ind.sig(message_physician_binary, comparison = c("Control","Physician recommend")),
            prop.ind.sig(message_people_binary, comparison = c("Control","People you know"))) %>%
    mutate(outcome = "Resistance"),
  bind_rows(prop.ind.sig(message_harm_continuous, comparison = c("Control","Preventing harm")),
            prop.ind.sig(message_patriotism_continuous, comparison = c("Control","Patriotism")),
            prop.ind.sig(message_scientists_continuous, comparison = c("Control","Scientists recommend")),
            prop.ind.sig(message_physician_continuous, comparison = c("Control","Physician recommend")),
            prop.ind.sig(message_people_continuous, comparison = c("Control","People you know"))) %>%
    mutate(outcome = "Likelihood")
) %>%
  # drop the "Control_" prefix from the comparison label
  mutate(comparison = gsub("Control_", "", comparison)) %>%
  mutate_at(.vars = c("neg.sig","pos.sig","null","above.avg","below.avg"),
            .funs = round, digits = 3) %>%
  write.csv(file = "results/tables/prop_ind_sig_messages.csv")

### variable importances

# Lookup table from model variable name to a display label. The variable
# names are taken, in order of first appearance, from the `importances` table
# of the continuous "Preventing harm" fit.
imps <- data.frame(var = unique(message_harm_continuous$importances$var))
# Labels are assigned by position, so this vector (36 entries) must line up
# one-to-one with the order of `imps$var`.
# NOTE(review): nothing here checks that the order matches; verify against
# the variable order stored in the fit if the model inputs change.
imps$label = c("Survey Date","Federal Government Underreacted",
               "COVID Concern: Self","COVID Concern: Family",
               "Perceived Case Trend Increase",
             #  "COVID News/Info from Facebook",
               "Behavior: Avoid Contact",
               "Behavior: Avoid Crowds",
               "Behavior: Wash Hands",
               "Behavior: Wear Mask",
               "Race: White","Race: Black",
               "Race: Latino","Race: Asian",
               "Race: Other Race",
               "Children <18 in Household",
               "College","Age","Gender",
               "USR: Rural","USR: Suburban","USR: Urban",
               "Political Interest","Party ID","Ideology",
               "Household Income",
               "Region: Midwest", "Region: Northeast","Region: South","Region: West",
               "COVID Diagnosed: Self","COVID Suspected","COVID Diagnosed: Family",
               "County Cases per 1000 (Cumulative)","County Cases per 1000 (New)",
               "30 Day New Case Rate","30 Day New Death Rate")

# make plots
# Panel A: variable importances from the binary (resistance) fits, one bar
# per variable x condition, keeping only bars with importance >= .1.
imps_resistance <- 
  imps %>% 
  # attach each condition's importance column by variable name
  left_join(message_harm_binary$importances %>%
              arrange(desc(imp_Control_Preventing.harm)),
            by = "var") %>%
  left_join(message_people_binary$importances %>%
              arrange(desc(imp_Control_People.you.know)),
            by = "var") %>%
  left_join(message_physician_binary$importances %>%
              arrange(desc(imp_Control_Physician.recommend)),
            by = "var") %>%
  left_join(message_scientists_binary$importances %>%
              arrange(desc(imp_Control_Scientists.recommend)),
            by = "var") %>%
  left_join(message_patriotism_binary$importances %>%
              arrange(desc(imp_Control_Patriotism)),
            by = 'var') %>%
  # long format: one row per variable x importance column
  reshape2::melt(id.vars = c("var","label")) %>%
  # variables missing from a fit count as importance 0; the column name minus
  # its "imp_Control_" prefix is the condition name
  mutate(value = ifelse(is.na(value), 0, value),
         variable = gsub("imp_Control_", "", variable)) %>%
  # rank the conditions within each variable (used below as the dodge group),
  # computed before the threshold filter so ranks refer to all conditions
  group_by(var) %>%
  mutate(position = rank(value)) %>%
  filter(value >= .1) %>%
  ungroup() %>%
  # sort by importance so that fct_inorder() orders labels by their largest
  # remaining bar; fct_rev() reverses the levels for the flipped axis
  arrange(desc(value)) %>%
  ggplot(aes(x = fct_rev(fct_inorder(label)),
             y = value, fill = variable, group = factor(position)))+
  geom_bar(position = position_dodge(preserve = "single"),
           stat = "identity")+
  scale_fill_OkabeIto(name = "Condition",
                      breaks = c("Preventing.harm","People.you.know",
                                 "Physician.recommend","Scientists.recommend",
                                 "Patriotism"),
                      labels =c("Preventing Harm","People You Know",
                                "Physician Recommend","Scientists Recommend",
                                "Patriotism"))+
  labs(y = "Variable Importance",
       title = "A. Resistance")+
  scale_x_discrete(name = "")+
  coord_flip()+
  theme_jg()


# Panel B: variable importances from the continuous (likelihood) fits, one
# bar per variable x condition, keeping only bars with importance >= .1.
imps_likelihood <- 
  imps %>%
  # attach each condition's importance column by variable name
  left_join(message_harm_continuous$importances %>%
              arrange(desc(imp_Control_Preventing.harm)),
            by = "var") %>%
  left_join(message_people_continuous$importances %>%
              arrange(desc(imp_Control_People.you.know)),
            by = "var") %>%
  left_join(message_physician_continuous$importances %>%
              arrange(desc(imp_Control_Physician.recommend)),
            by = "var") %>%
  left_join(message_scientists_continuous$importances %>%
              arrange(desc(imp_Control_Scientists.recommend)),
            by = "var") %>%
  left_join(message_patriotism_continuous$importances %>%
              arrange(desc(imp_Control_Patriotism)),
            by = 'var') %>%
  # long format: one row per variable x importance column
  reshape2::melt(id.vars = c("var","label")) %>%
  # variables missing from a fit count as importance 0; the column name minus
  # its "imp_Control_" prefix is the condition name
  mutate(value = ifelse(is.na(value), 0, value),
         variable = gsub("imp_Control_", "", variable)) %>%
  # rank the conditions within each variable (used below as the dodge group),
  # computed before the threshold filter so ranks refer to all conditions
  group_by(var) %>%
  mutate(position = rank(value)) %>%
  filter(value >= .1) %>%
  ungroup() %>%
  # sort by importance so that fct_inorder() orders labels by their largest
  # remaining bar; fct_rev() reverses the levels for the flipped axis
  arrange(desc(value)) %>%
  ggplot(aes(x = fct_rev(fct_inorder(label)),
             y = value, fill = variable, group = factor(position)))+
  geom_bar(position = position_dodge(preserve = "single"),
           stat = "identity")+
  scale_fill_OkabeIto(name = "Condition",
                      breaks = c("Preventing.harm","People.you.know",
                                 "Physician.recommend","Scientists.recommend",
                                 "Patriotism"),
                      labels =c("Preventing Harm","People You Know",
                                "Physician Recommend","Scientists Recommend",
                                "Patriotism"))+
  labs(y = "Variable Importance",
       title = "B. Likelihood")+
  scale_x_discrete(name = "")+
  coord_flip()+
  theme_jg()

# organize plots
# Figure-level heading for the variable-importance figure. The label is drawn
# on an empty cowplot canvas and styled from the "plot.title" element of the
# theme passed to draw_label_theme().
title <- 
  ggdraw() +
  draw_label_theme(
    label = "Variable Importance",
    theme = theme_bw() +
      theme(
        text = element_text(family = "serif", size = 24),
        plot.title = element_text(face = "bold", size = 30),
        plot.subtitle = element_text(size = 20),
        plot.caption = element_text(size = 16)
      ),
    element = "plot.title",
    x = 0.05,
    hjust = 0,
    vjust = 1
  )

# Explanatory sub-heading, styled from the "plot.subtitle" element of the
# same theme specification.
subtitle <- 
  ggdraw() +
  draw_label_theme(
    label = "Bars reflect importance of variables with importance > .1 in given condition\nImportance represents weighted sum of how often feature was used for splitting",
    theme = theme_bw() +
      theme(
        text = element_text(family = "serif", size = 24),
        plot.title = element_text(face = "bold", size = 30),
        plot.subtitle = element_text(size = 20),
        plot.caption = element_text(size = 16)
      ),
    element = "plot.subtitle",
    x = 0.05,
    hjust = 0,
    vjust = 1
  )

# Stack the resistance (A) and likelihood (B) importance panels in a single
# column, both on a common 0-.25 importance range.
# Fix: the previous call passed `widths = c(1, 1)`. plot_grid() has no
# `widths` argument (the relative-size arguments are `rel_widths` /
# `rel_heights`), so the vector was collected by `...` and handled as a third
# "plot". With a single column no relative widths are needed, so it is dropped.
# NOTE(review): with align = "v", cowplot's `axis` is normally "l", "r" or
# "lr"; axis = "t" is kept as written -- confirm it is what was intended.
imps_draw <- cowplot::plot_grid(imps_resistance + ylim(c(0, .25)),
                                imps_likelihood + ylim(c(0, .25)),
                                nrow = 2, ncol = 1,
                                align = "v",
                                axis = "t")
# plot
# Final figure: heading, sub-heading, then the stacked panels.
imps_out <- 
  plot_grid(title, subtitle,
            imps_draw,
            ncol = 1,
            rel_heights = c(0.02, 0.15, .8))
# Arguments spelled out instead of relying on `file` partially matching
# ggsave()'s `filename` argument.
ggsave(filename = "results/figures/imps_out.png",
       plot = imps_out,
       width = 12, height = 8)

# check patriotism fx by ideology
# Boxplots of the predicted patriotism-message effects from the binary
# (resistance) model, one box per ideology value. Ideology is attached to the
# predictions by respondent id (`psid`), and only rows whose `inout` is
# "crosstrain_Control_Patriotism" are plotted.
patriotism_by_ideo_resistance <- 
  message_patriotism_binary$predictions %>%
  left_join(
    dplyr::select(as_tibble(wave14.c.exp.message), psid, ideology),
    by = "psid"
  ) %>%
  filter(inout == "crosstrain_Control_Patriotism") %>%
  ggplot(aes(x = factor(ideology),
             y = predictions_Control_Patriotism)) +
  # dashed reference line at zero predicted effect
  geom_hline(yintercept = 0, lty = "dashed", alpha = .5) +
  geom_boxplot() +
  scale_x_discrete(
    name = "Ideological Identity",
    breaks = c(1, 2, 3, 4, 5, 6, 7),
    labels = c("Extremely Liberal", "Liberal", "Slightly Liberal",
               "Moderate", "Slightly Conservative",
               "Conservative", "Extremely Conservative")
  ) +
  labs(
    y = "Predicted Effect",
    title = "Predicted Effects of Patriotism Message by Ideological Identity",
    subtitle = "Change in probability of respondent saying they are 'extremely unlikely' to take COVID-19 vaccine"
  ) +
  theme_jg()

ggsave(filename = "results/figures/patriotism_by_ideo_resistance.png",
       plot = patriotism_by_ideo_resistance,
       width = 12, height = 5)


# plot rank against value

# check re: ceiling effect by income
# Bar chart of the share of control-group respondents with vac_message == 1
# (plotted as "resistance") within each household-income group.
#
# Income is binned with cut() on right-closed intervals
#   (-Inf, 22499], (22499, 50000], (50000, 94999], (94999, Inf)
# which gives the same groups as the earlier `<=` / `%in% a:b` / `>=` rules
# for whole-dollar incomes, but no longer drops non-integer incomes to NA
# (`%in% c(22500:50000)` only matches exact integers).
ceiling_fx_control <- 
  wave14.c.exp.message %>%
  mutate(hhi_group = cut(houseHoldIncome,
                         breaks = c(-Inf, 22499, 50000, 94999, Inf),
                         labels = c("First Quartile", "Second Quartile",
                                    "Third Quartile", "Fourth Quartile"))) %>%
  filter(vac_ex_mes_trunc == "Control" & !is.na(hhi_group) & !is.na(vac_message)) %>%
  group_by(hhi_group) %>%
  summarise(average_resistance = mean(vac_message == 1),
            average_likelihood = mean(vac_message)) %>%
  reshape2::melt(id.vars = c("hhi_group")) %>%
  # only the resistance measure is plotted; the likelihood average is
  # computed above but filtered out here
  filter(variable == "average_resistance") %>%
  ggplot(aes(x = fct_inorder(hhi_group),
             y = value))+
  # (disabled) two-panel version showing both measures:
  #facet_wrap(~variable, nrow = 1, scales = "free",
  #           labeller = as_labeller(c("average_resistance" = "Average Resistance (0-1) Scale)",
  #                                    "average_likelihood" = "Average Likelihood (1-7 Scale)")))+
  geom_bar(stat = "identity")+
  scale_x_discrete(name = "Household Income Quartile",
                   breaks = c("First Quartile","Second Quartile",
                              "Third Quartile", "Fourth Quartile"),
                   labels = c(1,2,3,4))+
  labs(y = "Average Resistance (0-1 Scale)", 
       title = "Vaccine Resistance by Household Income",
       subtitle = "Control group respondents")+
  theme_jg()+
  theme(text = element_text(size = 18, family = "serif"),
        plot.title = element_text(face = "bold", size = 28),
        plot.subtitle = element_text(size = 20))

### MAKE FIGURE 3, COMBINING IMPORTANCE WITH IND FX
# Figure 3 importance panel for the binary ("resistance") models.
# Per-condition importance tables are joined onto `imps` by `var`, reshaped to
# long form, and importances of at least .1 are drawn as dodged bars, with the
# legend placed underneath the panel.
imps_resistance_f3 <- 
  Reduce(
    function(joined, importances) left_join(joined, importances, by = "var"),
    list(
      imps,
      arrange(message_harm_binary$importances,
              desc(imp_Control_Preventing.harm)),
      arrange(message_people_binary$importances,
              desc(imp_Control_People.you.know)),
      arrange(message_physician_binary$importances,
              desc(imp_Control_Physician.recommend)),
      arrange(message_scientists_binary$importances,
              desc(imp_Control_Scientists.recommend)),
      arrange(message_patriotism_binary$importances,
              desc(imp_Control_Patriotism))
    )
  ) %>%
  reshape2::melt(id.vars = c("var", "label")) %>%
  # missing importances count as 0; drop the column prefix so `variable`
  # holds the bare condition name
  mutate(value = replace(value, is.na(value), 0),
         variable = gsub("imp_Control_", "", variable, fixed = TRUE)) %>%
  # within-variable rank of each condition, used as the dodge group
  group_by(var) %>%
  mutate(position = rank(value)) %>%
  filter(value >= .1) %>%
  ungroup() %>%
  # sort so that fct_inorder() orders the axis by importance
  arrange(desc(value)) %>%
  ggplot(aes(x = fct_rev(fct_inorder(label)),
             y = value,
             fill = variable,
             group = factor(position))) +
  geom_col(position = position_dodge(preserve = "single")) +
  scale_fill_OkabeIto(
    name = "Condition",
    breaks = c("Preventing.harm", "People.you.know",
               "Physician.recommend", "Scientists.recommend",
               "Patriotism"),
    labels = c("Preventing Harm", "People You Know",
               "Physician Recommend", "Scientists Recommend",
               "Patriotism")
  ) +
  labs(
    y = "Variable Importance",
    title = "Variable Importance, Resistance",
    subtitle = "Importance represents weighted splitting frequency\nImportances > .1 shown"
  ) +
  # wrap the longer variable labels onto two lines
  scale_x_discrete(
    name = "",
    breaks = c("Household Income", "Political Interest",
               "County Cases per 1000 (Cumulative)",
               "Ideology", "Age",
               "County Cases per 1000 (New)",
               "Survey Date"),
    labels = c("Household\nIncome", "Political\nInterest",
               "County Cases/1000\n(Cumulative)",
               "Ideology", "Age",
               "County Cases/1000\n(New)",
               "Survey Date")
  ) +
  guides(fill = guide_legend(nrow = 3, byrow = TRUE)) +
  coord_flip() +
  theme_jg() +
  theme(
    text = element_text(size = 18, family = "serif"),
    plot.title = element_text(face = "bold", size = 28),
    plot.subtitle = element_text(size = 20),
    plot.caption = element_text(hjust = 0, size = 18),
    legend.position = "bottom",
    legend.direction = "vertical",
    legend.margin = margin(),
    legend.title.align = 0.5,
    legend.spacing.y = unit(.1, "cm")
  )

# Figure 3 importance panel for the continuous ("likelihood") models.
# Same construction as the resistance panel, but reading the importance
# tables of the *_continuous model objects.
imps_likelihood_f3 <- 
  Reduce(
    function(joined, importances) left_join(joined, importances, by = "var"),
    list(
      imps,
      arrange(message_harm_continuous$importances,
              desc(imp_Control_Preventing.harm)),
      arrange(message_people_continuous$importances,
              desc(imp_Control_People.you.know)),
      arrange(message_physician_continuous$importances,
              desc(imp_Control_Physician.recommend)),
      arrange(message_scientists_continuous$importances,
              desc(imp_Control_Scientists.recommend)),
      arrange(message_patriotism_continuous$importances,
              desc(imp_Control_Patriotism))
    )
  ) %>%
  reshape2::melt(id.vars = c("var", "label")) %>%
  # missing importances count as 0; drop the column prefix so `variable`
  # holds the bare condition name
  mutate(value = replace(value, is.na(value), 0),
         variable = gsub("imp_Control_", "", variable, fixed = TRUE)) %>%
  # within-variable rank of each condition, used as the dodge group
  group_by(var) %>%
  mutate(position = rank(value)) %>%
  filter(value >= .1) %>%
  ungroup() %>%
  # sort so that fct_inorder() orders the axis by importance
  arrange(desc(value)) %>%
  ggplot(aes(x = fct_rev(fct_inorder(label)),
             y = value,
             fill = variable,
             group = factor(position))) +
  geom_col(position = position_dodge(preserve = "single")) +
  scale_fill_OkabeIto(
    name = "Condition",
    breaks = c("Preventing.harm", "People.you.know",
               "Physician.recommend", "Scientists.recommend",
               "Patriotism"),
    labels = c("Preventing Harm", "People You Know",
               "Physician Recommend", "Scientists Recommend",
               "Patriotism")
  ) +
  labs(
    y = "Variable Importance",
    title = "Variable Importance, Likelihood",
    subtitle = "Importance represents weighted splitting frequency\nImportances > .1 shown"
  ) +
  # wrap the longer variable labels onto two lines
  scale_x_discrete(
    name = "",
    breaks = c("Household Income", "Political Interest",
               "County Cases per 1000 (Cumulative)",
               "Ideology", "Age",
               "County Cases per 1000 (New)",
               "Survey Date"),
    labels = c("Household\nIncome", "Political\nInterest",
               "County Cases/1000\n(Cumulative)",
               "Ideology", "Age",
               "County Cases/1000\n(New)",
               "Survey Date")
  ) +
  guides(fill = guide_legend(nrow = 3, byrow = TRUE)) +
  coord_flip() +
  theme_jg() +
  theme(
    text = element_text(size = 18, family = "serif"),
    plot.title = element_text(face = "bold", size = 28),
    plot.subtitle = element_text(size = 20),
    plot.caption = element_text(hjust = 0, size = 18),
    legend.position = "bottom",
    legend.direction = "vertical",
    legend.margin = margin(),
    legend.title.align = 0.5,
    legend.spacing.y = unit(.1, "cm")
  )

# Sorted predicted effects of the scientists message on vaccine likelihood,
# coloured by household-income group.
# Each point is one prediction (rows whose `inout` is
# "crosstrain_Control_Scientists.recommend"), placed at its rank, with a grey
# +/- 1.96 * sqrt(variance estimate) interval behind it. The solid line marks
# zero; the dashed line marks the first element returned by
# grf::average_treatment_effect() for the fitted forest.
#
# Income is binned with cut() on right-closed intervals
#   (-Inf, 22499], (22499, 50000], (50000, 94999], (94999, Inf)
# which gives the same groups as the earlier `<=` / `%in% a:b` / `>=` rules
# for whole-dollar incomes, but no longer drops non-integer incomes to NA
# (`%in% c(22500:50000)` only matches exact integers).
scientists_likelihood_rank_hhiQuartile_f3 <-
  message_scientists_continuous$predictions %>%
  left_join(as_tibble(wave14.c.exp.message) %>% dplyr::select(psid, houseHoldIncome),
            by = "psid") %>%
  mutate(hhi_group = cut(houseHoldIncome,
                         breaks = c(-Inf, 22499, 50000, 94999, Inf),
                         labels = c("First Quartile", "Second Quartile",
                                    "Third Quartile", "Fourth Quartile"))) %>%
  filter(inout == "crosstrain_Control_Scientists.recommend") %>%
  ggplot(aes(x = rank(predictions_Control_Scientists.recommend), 
             y = predictions_Control_Scientists.recommend,
             col = factor(hhi_group)))+
  # interval layer drawn first so the coloured points sit on top of it
  geom_pointrange(aes(ymin = predictions_Control_Scientists.recommend - 1.96*sqrt(variance.estimates_Control_Scientists.recommend),
                      ymax = predictions_Control_Scientists.recommend + 1.96*sqrt(variance.estimates_Control_Scientists.recommend)),
                  col = "grey", alpha = .3)+
  geom_point(size = 2)+
  geom_hline(yintercept = 0)+
  geom_hline(yintercept = grf::average_treatment_effect(message_scientists_continuous$forest)[1],
             lty = "dashed")+
  coord_flip()+
  # enlarge legend keys so the colours stay readable
  guides(col=guide_legend(nrow=2,byrow=TRUE,
                          override.aes = list(size = 3)))+
  scale_color_OkabeIto(name = "Household Income")+
  labs(x = "Prediction Rank",
       y = "Predicted Change vs. Control (1-7 Scale, Average)",
       title = "Sorted Predicted Effects on Likelihood, Scientists",
       subtitle = "Predictions and 95% intervals\nAmong respondents in control or scientists condition",
       caption = "Note: Household income included as a continuous measure in modeling\nRecoded to quartiles for visualization")+
  theme_jg()+
  theme(text = element_text(size = 18, family = "serif"),
        plot.title = element_text(face = "bold", size = 28),
        plot.subtitle = element_text(size = 20),
        plot.caption = element_text(hjust = 0, size = 18),
        legend.position = "bottom",
        legend.direction = "vertical",
        legend.title.align = 0.5,
        legend.margin = margin(),
        legend.key.size = unit(.1, "cm"))


# Sorted predicted effects of the physician message on vaccine resistance,
# coloured by household-income group.
# Each point is one prediction (rows whose `inout` is
# "crosstrain_Control_Physician.recommend"), placed at its rank, with a grey
# +/- 1.96 * sqrt(variance estimate) interval behind it. The solid line marks
# zero; the dashed line marks the first element returned by
# grf::average_treatment_effect() for the fitted forest.
#
# Income is binned with cut() on right-closed intervals
#   (-Inf, 22499], (22499, 50000], (50000, 94999], (94999, Inf)
# which gives the same groups as the earlier `<=` / `%in% a:b` / `>=` rules
# for whole-dollar incomes, but no longer drops non-integer incomes to NA
# (`%in% c(22500:50000)` only matches exact integers).
physician_resistance_rank_hhiQuartile_f3 <-
  message_physician_binary$predictions %>%
  left_join(as_tibble(wave14.c.exp.message) %>% dplyr::select(psid, houseHoldIncome),
            by = "psid") %>%
  mutate(hhi_group = cut(houseHoldIncome,
                         breaks = c(-Inf, 22499, 50000, 94999, Inf),
                         labels = c("First Quartile", "Second Quartile",
                                    "Third Quartile", "Fourth Quartile"))) %>%
  filter(inout == "crosstrain_Control_Physician.recommend") %>%
  ggplot(aes(x = rank(predictions_Control_Physician.recommend), 
             y = predictions_Control_Physician.recommend,
             col = factor(hhi_group)))+
  # interval layer drawn first so the coloured points sit on top of it
  geom_pointrange(aes(ymin = predictions_Control_Physician.recommend - 1.96*sqrt(variance.estimates_Control_Physician.recommend),
                      ymax = predictions_Control_Physician.recommend + 1.96*sqrt(variance.estimates_Control_Physician.recommend)),
                  col = "grey", alpha = .3)+
  geom_point(size = 2)+
  geom_hline(yintercept = 0)+
  geom_hline(yintercept = grf::average_treatment_effect(message_physician_binary$forest)[1],
             lty = "dashed")+
  coord_flip()+
  # enlarge legend keys so the colours stay readable
  guides(col=guide_legend(nrow=2,byrow=TRUE,
                          override.aes = list(size = 3)))+
  scale_color_OkabeIto(name = "Household Income")+
  labs(x = "Prediction Rank",
       y = "Predicted Change vs. Control",
       title = "Predicted Effects on Probability of Vaccine Resistance\nPhysician Condition by Household Income",
       subtitle = "Predictions and 95% intervals\nAmong respondents in control or physician condition",
       caption = "Note: Household income included as a continuous measure in modeling\nRecoded to quartiles for visualization")+
  theme_jg()+
  theme(text = element_text(size = 18, family = "serif"),
        plot.title = element_text(face = "bold", size = 28),
        plot.subtitle = element_text(size = 20),
        plot.caption = element_text(hjust = 0, size = 18),
        legend.position = "bottom",
        legend.direction = "vertical",
        legend.title.align = 0.5,
        legend.margin = margin(),
        legend.key.size = unit(.1, "cm"))


# Group average treatment effects of the scientists message on vaccine
# likelihood, by household-income quartile (n_tiles = 4), built with the
# project helper plot.var.quantile.basic() and then relabelled and restyled.
# Fix: the plot title previously read "Scienstist Condition" (typo).
scientists_likelihood_hhi_quartile <- 
  plot.var.quantile.basic(message_scientists_continuous,
                          effect_var = "houseHoldIncome",
                          n_tiles = 4,
                          comparison = c("Control","Scientists.recommend"))+
  labs(x = "Household Income Quartile",
       y = "Group Average Treatment Effect\n(1-7 Scale)",
       title = "Group Average Treatment Effects\nScientists Condition by Household Income",
       subtitle = "Average treatment effect with 95% uncertainty interval\nShown with dashed line in shaded band")+
  theme_jg()+
  theme(text = element_text(size = 18, family = "serif"),
        plot.title = element_text(face = "bold", size = 28),
        plot.subtitle = element_text(size = 20))

# Share of control-group respondents with vac_message == 1 ("resistance") in
# each household-income group (coded 1-4), with a normal-approximation 95%
# interval (mean +/- 1.96 * sd / sqrt(n)).
#
# Income is binned with cut() on right-closed intervals
#   (-Inf, 22499], (22499, 50000], (50000, 94999], (94999, Inf)
# which gives the same codes as the earlier `<=` / `%in% a:b` / `>=` rules
# for whole-dollar incomes, but no longer drops non-integer incomes to NA
# (`%in% c(22500:50000)` only matches exact integers).
control_resistance_hhi <- 
  wave14.c.exp.message %>%
  # labels = FALSE returns integer bin codes; as.numeric() keeps the column
  # a double, as it was when built from the literals 1-4
  mutate(hhi_group = as.numeric(cut(houseHoldIncome,
                                    breaks = c(-Inf, 22499, 50000, 94999, Inf),
                                    labels = FALSE))) %>%
  filter(vac_ex_mes_trunc == "Control" & !is.na(hhi_group) & !is.na(vac_message)) %>%
  group_by(hhi_group) %>%
  summarise(average_resistance = mean(vac_message == 1),
            sd_resistance = sd(vac_message == 1),
            n = n()) %>%
  mutate(se = sd_resistance / sqrt(n)) %>%
  mutate(lwr = average_resistance - 1.96*se,
         upr = average_resistance + 1.96*se)

# Group average treatment effects of the physician message on resistance, by
# household-income quartile (n_tiles = 4), built with the project helper
# plot.var.quantile.basic(). The control-group resistance share for each
# income group is printed as a text label at y = 0.02.
physician_resistance_hhi_quartile <- 
  plot.var.quantile.basic(
    message_physician_binary,
    effect_var = "houseHoldIncome",
    n_tiles = 4,
    comparison = c("Control", "Physician.recommend")
  ) +
  geom_text(
    # the label layer maps `quantile` to x, so rename the group column to match
    data = rename(control_resistance_hhi, quantile = hhi_group),
    mapping = aes(x = quantile,
                  y = 0.02,
                  label = round(average_resistance, 2)),
    family = "serif",
    size = 8
  ) +
  labs(
    x = "Household Income Quartile",
    y = "Group Average Treatment Effect\n(Binary Outcome, Proportion)",
    title = "Group Average Treatment Effects, Resistance\nPhysician Condition by Household Income",
    subtitle = "ATE with 95% uncertainty interval shown with dashed line in shaded band\nProportion resistant in control condition labeled"
  ) +
  theme_jg() +
  theme(
    text = element_text(size = 20, family = "serif"),
    plot.title = element_text(face = "bold", size = 28),
    plot.subtitle = element_text(size = 20)
  )

# Figure 3 heading, drawn on an empty cowplot canvas and styled from the
# "plot.title" element of the supplied theme. This overwrites the `title`
# object created for the variable-importance figure.
title <- 
  ggdraw() +
  draw_label_theme(
    label = "Testing for Effect Heterogeneity",
    theme = theme_bw() +
      theme(
        text = element_text(family = "serif", size = 24),
        plot.title = element_text(face = "bold", size = 35),
        plot.subtitle = element_text(size = 20),
        plot.caption = element_text(size = 16)
      ),
    element = "plot.title",
    x = 0.05,
    hjust = 0
  )

# Figure 3 panel guide. It is also styled from "plot.title", which this theme
# sets to a plain size-20 face. This overwrites the earlier `subtitle` object.
subtitle <- 
  ggdraw() +
  draw_label_theme(
    label = "Top left panel shows most important variables for vaccine likelihood outcome\nTop right panel shows variation in effects at highest-importance variable x condition\nBottom panel shows corresponding group average treatment effects by household income quartile",
    theme = theme_bw() +
      theme(
        text = element_text(family = "serif", size = 24),
        plot.title = element_text(size = 20),
        plot.subtitle = element_text(size = 20),
        plot.caption = element_text(size = 16)
      ),
    element = "plot.title",
    x = 0.05,
    hjust = 0
  )

# Figure 3 (resistance): importance panel (A), sorted physician predictions
# (B) and physician group effects by income quartile (C) on a two-column
# grid, with panel letters in large serif type.
# NOTE(review): rel_widths has three entries but the grid has two columns;
# kept exactly as written -- confirm the intended column proportions.
f3_draw_resistance <- cowplot::plot_grid(
  imps_resistance_f3,
  physician_resistance_rank_hhiQuartile_f3,
  physician_resistance_hhi_quartile,
  ncol = 2,
  rel_widths = c(1, 1, 2),
  labels = c("A.", "B.", "C."),
  label_size = 30,
  label_fontfamily = "serif"
)

# plot
# Heading on top of the panel grid; `subtitle` is deliberately left out of
# the layout here.
f3_out <- 
  plot_grid(
    title,
    f3_draw_resistance,
    ncol = 1,
    rel_heights = c(0.05, 0.95)
  )

ggsave(filename = "results/figures/f3_draw_resistance.png",
       plot = f3_out,
       width = 20, height = 20)
